{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Human-in-the-loop Breakpoints\n",
    "\n",
    "- streams & user input injection\n",
    "\n",
    "## Approval\n",
    "\n",
    "Use case: the user approves each tool call before the agent executes it.\n",
    "\n",
    "What we need:\n",
    "\n",
    "- a way to stop the execution\n",
    "- a way to resume the execution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Redefine the financial advisor graph (from 03_agent.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import yfinance as yf\n",
    "from pprint import pformat\n",
    "from langchain_core.tools import Tool\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_core.messages import HumanMessage, SystemMessage\n",
    "from langgraph.graph import MessagesState, START, StateGraph\n",
    "from langgraph.prebuilt import tools_condition, ToolNode\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "from IPython.display import Image, display\n",
    "\n",
    "\n",
    "# Defining Tools\n",
    "##################################################################################\n",
    "\n",
    "def lookup_stock_symbol(company_name: str) -> str:\n",
    "    \"\"\"\n",
    "    Converts a company name to its stock symbol using a financial API.\n",
    "\n",
    "    Parameters:\n",
    "        company_name (str): The full company name (e.g., 'Tesla').\n",
    "\n",
    "    Returns:\n",
    "        str: The stock symbol (e.g., 'TSLA') or an error message.\n",
    "    \"\"\"\n",
    "    api_url = \"https://www.alphavantage.co/query\"\n",
    "    params = {\n",
    "        \"function\": \"SYMBOL_SEARCH\",\n",
    "        \"keywords\": company_name,\n",
    "        \"apikey\": \"your_alphavantage_api_key\"\n",
    "    }\n",
    "    \n",
    "    response = requests.get(api_url, params=params, timeout=10)  # Avoid hanging on a slow API\n",
    "    data = response.json()\n",
    "    \n",
    "    if \"bestMatches\" in data and data[\"bestMatches\"]:\n",
    "        return data[\"bestMatches\"][0][\"1. symbol\"]\n",
    "    else:\n",
    "        return f\"Symbol not found for {company_name}.\"\n",
    "\n",
    "\n",
    "def fetch_stock_data_raw(stock_symbol: str) -> str:\n",
    "    \"\"\"\n",
    "    Fetches comprehensive stock data for a given symbol and returns it as a pretty-printed string.\n",
    "\n",
    "    Parameters:\n",
    "        stock_symbol (str): The stock ticker symbol (e.g., 'TSLA').\n",
    "\n",
    "    Returns:\n",
    "        str: A pretty-printed dictionary combining general stock info and\n",
    "            one month of historical market data, or an error message.\n",
    "    \"\"\"\n",
    "    period = \"1mo\"  # Fixed look-back window for the historical data\n",
    "    try:\n",
    "        stock = yf.Ticker(stock_symbol)\n",
    "\n",
    "        # Retrieve general stock info and historical market data\n",
    "        stock_info = stock.info  # Basic company and stock data\n",
    "        stock_history = stock.history(period=period).to_dict()  # Historical OHLCV data\n",
    "\n",
    "        # Combine both into a single dictionary\n",
    "        combined_data = {\n",
    "            \"stock_symbol\": stock_symbol,\n",
    "            \"info\": stock_info,\n",
    "            \"history\": stock_history\n",
    "        }\n",
    "\n",
    "        # Return a string so the LLM receives readable text\n",
    "        return pformat(combined_data)\n",
    "\n",
    "    except Exception as e:\n",
    "        return f\"Error fetching stock data for {stock_symbol}: {str(e)}\"\n",
    "\n",
    "\n",
    "# Binding tools to the LLM\n",
    "##################################################################################\n",
    "\n",
    "# Create tool bindings with additional attributes\n",
    "lookup_stock = Tool.from_function(\n",
    "    func=lookup_stock_symbol,\n",
    "    name=\"lookup_stock_symbol\",\n",
    "    description=\"Converts a company name to its stock symbol using a financial API.\",\n",
    "    return_direct=False  # Return result to be processed by LLM\n",
    ")\n",
    "\n",
    "fetch_stock = Tool.from_function(\n",
    "    func=fetch_stock_data_raw,\n",
    "    name=\"fetch_stock_data_raw\",\n",
    "    description=\"Fetches comprehensive stock data including general info and historical market data for a given stock symbol.\",\n",
    "    return_direct=False\n",
    ")\n",
    "\n",
    "toolbox = [lookup_stock, fetch_stock]\n",
    "\n",
    "# OPENAI_API_KEY environment variable must be set\n",
    "simple_llm = ChatOpenAI(model=\"gpt-4o-mini\")\n",
    "llm_with_tools = simple_llm.bind_tools(toolbox)\n",
    "\n",
    "\n",
    "# Defining Agent's node\n",
    "##################################################################################\n",
    "\n",
    "# System message\n",
    "assistant_system_message = SystemMessage(content=(\"\"\"\n",
    "You are a professional financial assistant specializing in stock market analysis and investment strategies. \n",
    "Your role is to analyze stock data and provide **clear, decisive recommendations** that users can act on, \n",
    "whether they already hold the stock or are considering investing.\n",
    "\n",
    "You have access to a set of tools that can provide the data you need to analyze stocks effectively. \n",
    "Use these tools to gather relevant information such as stock symbols, current prices, historical trends, \n",
    "and key financial indicators. Your goal is to leverage these resources efficiently to generate accurate, \n",
    "actionable insights for the user.\n",
    "\n",
    "Your responses should be:\n",
    "- **Concise and direct**, summarizing only the most critical insights.\n",
    "- **Actionable**, offering clear guidance on whether to buy, sell, hold, or wait for better opportunities.\n",
    "- **Context-aware**, considering both current holders and potential investors.\n",
    "- **Free of speculation**, relying solely on factual data and trends.\n",
    "\n",
    "### Response Format:\n",
    "1. **Recommendation:** Buy, Sell, Hold, or Wait.\n",
    "2. **Key Insights:** Highlight critical trends and market factors that influence the decision.\n",
    "3. **Suggested Next Steps:** What the user should do based on their current position.\n",
    "\n",
    "If the user does not specify whether they own the stock, provide recommendations for both potential buyers and current holders. Ensure your advice considers valuation, trends, and market sentiment.\n",
    "\n",
    "Your goal is to help users make informed financial decisions quickly and confidently.\n",
    "\"\"\"))\n",
    "\n",
    "# Node\n",
    "def assistant(state: MessagesState):\n",
    "    return {\"messages\": [llm_with_tools.invoke([assistant_system_message] + state[\"messages\"])]}\n",
    "\n",
    "\n",
    "# Defining Graph\n",
    "##################################################################################\n",
    "# Graph\n",
    "builder = StateGraph(MessagesState)\n",
    "\n",
    "# Define nodes: these do the work\n",
    "builder.add_node(\"assistant\", assistant)\n",
    "builder.add_node(\"tools\", ToolNode(toolbox))\n",
    "\n",
    "# Define edges: these determine how the control flow moves\n",
    "builder.add_edge(START, \"assistant\")\n",
    "builder.add_conditional_edges(\n",
    "    \"assistant\",\n",
    "    # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools\n",
    "    # If the latest message (result) from assistant is not a tool call -> tools_condition routes to END\n",
    "    tools_condition,\n",
    ")\n",
    "builder.add_edge(\"tools\", \"assistant\")\n",
    "\n",
    "memory = MemorySaver()\n",
    "graph = builder.compile(checkpointer=memory)\n",
    "display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adding a breakpoint to the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = builder.compile(interrupt_before=[\"tools\"], checkpointer=memory)\n",
    "\n",
    "display(Image(graph.get_graph(xray=True).draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's test the breakpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# start a new conversation\n",
    "thread = {\"configurable\": {\"thread_id\": \"1\"}}\n",
    "\n",
    "# define initial user request\n",
    "initial_input = {\"messages\": HumanMessage(content=\"Should I invest in Tesla stocks?\")}\n",
    "\n",
    "# run the graph and stream in values mode\n",
    "for event in graph.stream(initial_input, thread, stream_mode=\"values\"):\n",
    "    event['messages'][-1].pretty_print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check the graph's current state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state = graph.get_state(thread)\n",
    "state"
   ]
  },
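  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before asking for approval, it can help to condense the snapshot into the two facts that matter: which node runs next and which tool calls are waiting. The following is a minimal sketch, assuming the snapshot exposes `values` and `next` the way the object returned by `graph.get_state` does. The helper name `summarize_snapshot` is hypothetical, and the stand-in objects (including the tool-call arguments) are made up so the example runs without a graph.\n",
    "\n",
    "```python\n",
    "from types import SimpleNamespace\n",
    "\n",
    "\n",
    "def summarize_snapshot(snapshot):\n",
    "    '''Return the nodes scheduled to run next and the names of pending tool calls.'''\n",
    "    messages = snapshot.values.get('messages', [])\n",
    "    last_message = messages[-1] if messages else None\n",
    "    # Only AI messages carry tool calls, so fall back to an empty list.\n",
    "    tool_calls = getattr(last_message, 'tool_calls', None) or []\n",
    "    return {\n",
    "        'next': tuple(snapshot.next),\n",
    "        'pending_tools': [call['name'] for call in tool_calls],\n",
    "    }\n",
    "\n",
    "\n",
    "# Stand-in for graph.get_state(thread); the values are illustrative only.\n",
    "fake_message = SimpleNamespace(\n",
    "    tool_calls=[{'name': 'lookup_stock_symbol', 'args': {'query': 'Tesla'}, 'id': 'call_1'}]\n",
    ")\n",
    "fake_snapshot = SimpleNamespace(values={'messages': [fake_message]}, next=('tools',))\n",
    "summary = summarize_snapshot(fake_snapshot)\n",
    "print(summary)\n",
    "```\n",
    "\n",
    "With the real graph, the same helper could be called as `summarize_snapshot(graph.get_state(thread))`."
   ]
  },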
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the next node in the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state.next"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Resuming execution\n",
    "\n",
    "### 1) Start a new execution flow from the provided `initial_input`\n",
    "\n",
    "```python\n",
    "graph.stream(initial_input, thread, stream_mode=\"values\")\n",
    "```\n",
    "\n",
    "### 2) Resume from the latest checkpoint stored in the thread state\n",
    "\n",
    "```python\n",
    "graph.stream(None, thread, stream_mode=\"values\")\n",
    "```\n",
    "\n",
    "Passing `None` as the input tells LangGraph to continue from the latest checkpoint recorded for the given `thread_id` instead of starting a new run. Because the graph state is preserved between calls, repeated approval steps and multi-step workflows become possible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# As its first event, LangGraph re-emits the current state (the AIMessage with the tool call).\n",
    "for event in graph.stream(None, thread, stream_mode=\"values\"):\n",
    "    event['messages'][-1].pretty_print()"
   ]
  },
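  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between passing an input and passing `None` can be illustrated with a toy model. This is only a sketch of the idea, not LangGraph's implementation: the `ToyGraph` class, its step names, and its `checkpoints` dictionary are all hypothetical, and it deliberately skips details such as re-emitting the current state on resume.\n",
    "\n",
    "```python\n",
    "class ToyGraph:\n",
    "    '''Runs named steps in order, pausing before any step listed in interrupt_before.'''\n",
    "\n",
    "    def __init__(self, steps, interrupt_before):\n",
    "        self.steps = steps\n",
    "        self.interrupt_before = set(interrupt_before)\n",
    "        self.checkpoints = {}  # thread_id -> index of the next step to run\n",
    "\n",
    "    def stream(self, graph_input, thread_id):\n",
    "        if graph_input is not None:\n",
    "            # New input: start the flow from the first step.\n",
    "            position = 0\n",
    "            resuming = False\n",
    "        else:\n",
    "            # None: continue from the checkpoint saved for this thread.\n",
    "            position = self.checkpoints[thread_id]\n",
    "            resuming = True\n",
    "        while position < len(self.steps):\n",
    "            step = self.steps[position]\n",
    "            if step in self.interrupt_before and not resuming:\n",
    "                # Save where we stopped and hand control back to the caller.\n",
    "                self.checkpoints[thread_id] = position\n",
    "                return\n",
    "            resuming = False\n",
    "            yield step\n",
    "            position += 1\n",
    "        # The flow finished, so there is nothing left to resume.\n",
    "        self.checkpoints.pop(thread_id, None)\n",
    "\n",
    "\n",
    "toy = ToyGraph(['assistant', 'tools', 'assistant'], interrupt_before=['tools'])\n",
    "first_run = list(toy.stream('Should I invest in Tesla stocks?', thread_id='1'))\n",
    "second_run = list(toy.stream(None, thread_id='1'))\n",
    "print(first_run)   # ['assistant']\n",
    "print(second_run)  # ['tools', 'assistant']\n",
    "```\n",
    "\n",
    "The first call stops before `tools` and records the position for the thread. The second call, given `None`, picks up at that position and runs the paused step."
   ]
  },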
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Putting it all together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start a new conversation\n",
    "thread = {\"configurable\": {\"thread_id\": \"3\"}}\n",
    "\n",
    "initial_input = {\"messages\": HumanMessage(content=\"Should I invest in Tesla stocks?\")}\n",
    "\n",
    "current_state_reemit = False\n",
    "\n",
    "# Run the graph in a loop, pausing at each tool-node call\n",
    "while True:\n",
    "    should_restart = False  # Reset flag at the start of each while-loop iteration\n",
    "\n",
    "    # Run the graph until the next interruption (breakpoint before tool-node)\n",
    "    for event in graph.stream(initial_input, thread, stream_mode=\"values\"):\n",
    "        event['messages'][-1].pretty_print()\n",
    "        \n",
    "        # Check if a tool is about to be called\n",
    "        last_message = event['messages'][-1]\n",
    "        if getattr(last_message, \"tool_calls\", None):\n",
    "            if current_state_reemit:\n",
    "                current_state_reemit = False\n",
    "                continue\n",
    "\n",
    "            else:\n",
    "                tool_name = last_message.tool_calls[0][\"name\"]  # Extract the tool name\n",
    "                # Get user feedback\n",
    "                user_approval = input(f\"Do you want to call the tool {tool_name}? (yes/no): \")\n",
    "\n",
    "                if user_approval.lower() == \"yes\":\n",
    "                    initial_input = None\n",
    "                    current_state_reemit = True\n",
    "                    should_restart = True\n",
    "                else:\n",
    "                    print(f\"\\n\\nTool call {tool_name} was cancelled by user\")\n",
    "\n",
    "                break\n",
    "\n",
    "    if should_restart:\n",
    "        continue\n",
    "    \n",
    "    # Once the graph completes without hitting a tool node, break out of the loop\n",
    "    break"
   ]
  },
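  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above asks about the first tool call only and accepts nothing but a literal `yes`. A slightly more forgiving approval step could list every pending call and normalize the answer. The sketch below assumes the message exposes a `tool_calls` list of dictionaries with `name` and `args` keys. The helper names `format_approval_prompt` and `is_approved` are hypothetical, and the stand-in message is made up so the example runs on its own.\n",
    "\n",
    "```python\n",
    "from types import SimpleNamespace\n",
    "\n",
    "\n",
    "def format_approval_prompt(message):\n",
    "    '''Build a prompt listing every tool call attached to the message.'''\n",
    "    tool_calls = getattr(message, 'tool_calls', None) or []\n",
    "    lines = ['- {}({})'.format(call['name'], call['args']) for call in tool_calls]\n",
    "    return 'Approve these tool calls? (yes/no)\\n' + '\\n'.join(lines)\n",
    "\n",
    "\n",
    "def is_approved(answer):\n",
    "    '''Treat y and yes (any case, surrounding spaces ignored) as approval.'''\n",
    "    return answer.strip().lower() in {'y', 'yes'}\n",
    "\n",
    "\n",
    "# Stand-in for the AI message emitted before the breakpoint; values are illustrative.\n",
    "fake_message = SimpleNamespace(\n",
    "    tool_calls=[{'name': 'lookup_stock_symbol', 'args': {'query': 'Tesla'}, 'id': 'call_1'}]\n",
    ")\n",
    "prompt = format_approval_prompt(fake_message)\n",
    "print(prompt)\n",
    "print(is_approved(' Yes '))\n",
    "```\n",
    "\n",
    "In the loop, `input(format_approval_prompt(last_message))` followed by `is_approved(...)` could stand in for the inline `yes` comparison."
   ]
  },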
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Breakpoints in LangGraph API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph_sdk import get_client\n",
    "\n",
    "URL = \"http://localhost:49381\"\n",
    "client = get_client(url=URL)\n",
    "\n",
    "assistants = await client.assistants.search()\n",
    "assistants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.messages import convert_to_messages\n",
    "\n",
    "thread = await client.threads.create()\n",
    "input_message = HumanMessage(content=\"Should I invest in Tesla stocks?\")\n",
    "\n",
    "async for event in client.runs.stream(\n",
    "            thread[\"thread_id\"], \n",
    "            assistant_id=\"b7480eb0-6390-53a5-9bc4-29bf27cbd1c4\", \n",
    "            input={\"messages\": [input_message]}, \n",
    "            stream_mode=\"values\",\n",
    "            interrupt_before=[\"tools\"],\n",
    "):\n",
    "    messages = event.data.get('messages', None)\n",
    "    if messages:\n",
    "        print(convert_to_messages(messages)[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async for event in client.runs.stream(\n",
    "            thread[\"thread_id\"], \n",
    "            assistant_id=\"b7480eb0-6390-53a5-9bc4-29bf27cbd1c4\", \n",
    "            input=None, \n",
    "            stream_mode=\"values\",\n",
    "            interrupt_before=[\"tools\"],\n",
    "):\n",
    "    messages = event.data.get('messages', None)\n",
    "    if messages:\n",
    "        print(convert_to_messages(messages)[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "async for event in client.runs.stream(\n",
    "            thread[\"thread_id\"], \n",
    "            assistant_id=\"b7480eb0-6390-53a5-9bc4-29bf27cbd1c4\", \n",
    "            input=None, \n",
    "            stream_mode=\"values\",\n",
    "            interrupt_before=[\"tools\"],\n",
    "):\n",
    "    messages = event.data.get('messages', None)\n",
    "    if messages:\n",
    "        print(convert_to_messages(messages)[-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Breakpoints in LangGraph Studio\n",
    "\n",
    "Let's copy/paste `financial_advisor` but with `interrupt_before=[\"tools\"]`. (`studio/financial_advisor.py` -> `studio/financial_advisor_breakpoint.py`)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
